package com.yifeng.repo.controller.traffic.in.processor.manager;

import com.yifeng.repo.base.utils.common.Snowflake;
import com.yifeng.repo.controller.traffic.in.TrafficInManager;
import com.yifeng.repo.controller.traffic.in.processor.worker.TrafficInWorker;
import com.yifeng.repo.controller.traffic.in.processor.model.TrafficInPolicy;
import com.yifeng.repo.controller.traffic.in.processor.model.TrafficInStats;
import com.yifeng.repo.controller.traffic.in.processor.worker.TrafficInWorkerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/**
 * Created by daibing on 2021/12/4.
 */
public class TrafficInStatsManager {
    private static final Logger LOGGER = LoggerFactory.getLogger(TrafficInStatsManager.class);
    /**
     * Local per-policy stats, keyed by policy code. All modifications of an entry's counters and
     * flags, and all replacements of entries, are done while holding this map's monitor.
     */
    private static final ConcurrentMap<String, TrafficInStats> policyCode2Stats = new ConcurrentHashMap<>();

    /**
     * Checks for a permit using the policy identified by a policy code.
     * Does nothing when the configured worker has no policy for {@code policyCode}.
     *
     * @param policyCode the traffic-in policy code
     * @throws RuntimeException if the policy exists but a permit could not be acquired
     */
    public static void checkPermit(String policyCode) {
        TrafficInPolicy policy = currentWorker().getPolicy(policyCode);
        if (policy == null) {
            return;
        }
        if (!acquireOne(policy)) {
            throw new RuntimeException("check permit by policyCode failed: policyCode=" + policyCode);
        }
    }

    /**
     * Checks for a permit using the policy the worker matches for a method-call context.
     * Does nothing when the configured worker returns no policy for this context.
     *
     * @param clazz     class name of the invoked method
     * @param method    name of the invoked method
     * @param argNumber number of arguments of the invoked method
     * @param argMap    invocation arguments, as passed to the worker's policy lookup
     * @throws RuntimeException if a policy matched but a permit could not be acquired
     */
    public static void checkPermit(String clazz, String method, int argNumber, Map<String, Object> argMap) {
        TrafficInPolicy policy = currentWorker().getPolicy(clazz, method, argNumber, argMap);
        if (policy == null) {
            return;
        }
        if (!acquireOne(policy)) {
            // Each value is reported under its own label (argNumber and argMap are distinct).
            throw new RuntimeException(String.format(
                    "check permit by context failed: clazz=%s, method=%s, argNumber=%s, argMap=%s, policyCode=%s",
                    clazz, method, argNumber, argMap, policy.getCode()));
        }
    }

    /**
     * Applies for permits under a policy.
     *
     * <p>Permits are requested from the worker in batches of {@code preApplyNum} and handed out
     * locally until the batch runs out. Once the worker refuses a batch, the policy is marked as
     * used up for the current period. Shared state is guarded by synchronizing on the stats map.
     *
     * @param trafficInPolicyCode policy code
     * @param periodMinutes       period length in minutes; must be non-zero (it is a divisor)
     * @param periodMaximumLimit  maximum number of permits per period
     * @param preApplyNum         number of permits to request from the worker per batch
     * @param applyNum            number of permits requested by this call
     * @return {@code true} if the permits were granted
     */
    public static boolean acquire(String trafficInPolicyCode, int periodMinutes, int periodMaximumLimit, int preApplyNum, int applyNum) {
        // 1. Resolve the worker.
        TrafficInWorker worker = currentWorker();

        // 2. Get the local stats for the current period, rebuilding them when the period rolled
        //    over or the policy's limits changed. floorBoundByPeriod is the number of whole
        //    periods since the epoch and therefore identifies the period boundary.
        TrafficInStats stats = policyCode2Stats.get(trafficInPolicyCode);
        long floorBoundByPeriod = System.currentTimeMillis() / (periodMinutes * 60 * 1000L);
        if (isStale(stats, floorBoundByPeriod, periodMinutes, periodMaximumLimit)) {
            synchronized (policyCode2Stats) {
                stats = policyCode2Stats.get(trafficInPolicyCode);
                if (isStale(stats, floorBoundByPeriod, periodMinutes, periodMaximumLimit)) {
                    stats = buildTrafficInStats(trafficInPolicyCode, floorBoundByPeriod, periodMinutes, periodMaximumLimit);
                    worker.insertOrUpdate(stats);
                    policyCode2Stats.put(trafficInPolicyCode, stats);
                }
            }
        }

        // 3. If the period is not used up and fewer permits are held locally than requested,
        //    request another batch (preApplyNum permits) from the worker.
        if (!stats.isUsedUp() && remaining(stats) < applyNum) {
            synchronized (policyCode2Stats) {
                // Re-check both conditions under the lock: another thread may already have
                // refilled the batch or found the period exhausted, so skip the remote call.
                if (!stats.isUsedUp() && remaining(stats) < applyNum) {
                    boolean success = worker.acquireBatch(trafficInPolicyCode, stats.getFloorBoundByPeriod(), preApplyNum);
                    if (!success) {
                        // Permits for this period are exhausted.
                        stats.setUsedUp(true);
                        return false;
                    }
                    // Carry over permits still held from the previous batch: they were already
                    // granted by the worker and would otherwise be lost.
                    stats.setStatsApplyTotal(stats.getStatsApplyTotal() - stats.getStatsUsedTotal() + preApplyNum);
                    stats.setStatsUsedTotal(0);
                }
            }
        }

        // 4. Hand out the permits if enough are held locally.
        synchronized (policyCode2Stats) {
            if (remaining(stats) >= applyNum) {
                stats.setStatsUsedTotal(stats.getStatsUsedTotal() + applyNum);
                LOGGER.info("traffic in limiter acquire ok: trafficInPolicyCode={}, floorBoundByPeriod={}, applyTotal={}, usedTotal={}",
                        trafficInPolicyCode, stats.getFloorBoundByPeriod(), stats.getStatsApplyTotal(), stats.getStatsUsedTotal());
                return true;
            }
        }
        return false;
    }

    /** Resolves the worker implementation configured on {@link TrafficInManager}. */
    private static TrafficInWorker currentWorker() {
        return TrafficInWorkerFactory.get(TrafficInManager.getWorkerClassName());
    }

    /** Acquires a single permit with the limits configured on {@code policy}. */
    private static boolean acquireOne(TrafficInPolicy policy) {
        return acquire(policy.getCode(), policy.getPeriodMinutes(), policy.getMaximumLimit(), policy.getPreApplyNum(), 1);
    }

    /** Returns true when {@code stats} is missing or belongs to another period or limit configuration. */
    private static boolean isStale(TrafficInStats stats, long floorBoundByPeriod, int periodMinutes, int periodMaximumLimit) {
        return stats == null
                || stats.getFloorBoundByPeriod() != floorBoundByPeriod
                || stats.getPeriodMinutes() != periodMinutes
                || stats.getMaximumLimit() != periodMaximumLimit;
    }

    /** Number of permits granted by the worker but not yet handed out locally. */
    private static long remaining(TrafficInStats stats) {
        return stats.getStatsApplyTotal() - stats.getStatsUsedTotal();
    }

    /** Creates fresh, empty stats for a policy in the given period. */
    private static TrafficInStats buildTrafficInStats(String policyCode, long floorBoundByPeriod, int periodMinutes, int periodMaximumLimit) {
        TrafficInStats stats = new TrafficInStats();
        stats.setId(Snowflake.get().nextId());
        stats.setTrafficInPolicyCode(policyCode);
        stats.setPeriodMinutes(periodMinutes);
        stats.setMaximumLimit(periodMaximumLimit);
        stats.setFloorBoundByPeriod(floorBoundByPeriod);
        stats.setStatsApplyTotal(0);
        stats.setStatsUsedTotal(0);
        stats.setUsedUp(false);
        return stats;
    }

}
